import logging, Memory, Clock

class DataCache():
    '''Set-associative data cache simulator with LRU replacement.

    The cache holds ``nBlocks`` blocks grouped into ``numLines`` sets of
    ``assoc`` blocks each. The blocks of set ``s`` occupy indices
    ``s*assoc .. s*assoc + assoc - 1`` of three parallel per-block lists:

    - ``blocks``: stored tag, or ``None`` when empty.
    - ``LRUcount``: accesses to the set since the block was last used.
    - ``dirty``: write-back dirty bit.

    Misses are served by a ``Memory.RAM`` backing store that is given this
    cache's clock.
    '''

    # CDS: A Pythonic enum http://stackoverflow.com/a/1695250
    def enum(self, *sequential, **named):
        '''Build a simple enum-like class.

        Positional names are numbered 0, 1, 2, ...; keyword arguments are
        added with the values given.
        '''
        enums = dict(zip(sequential, range(len(sequential))), **named)
        return type('Enum', (), enums)

    def __str__(self):
        return 'DataCache'

    def __init__(self, size, blockSize, associativity,
                 policy, allocate, cachePenalty, memPenalty):
        '''Create a cache and its RAM backing store.

        size          -- capacity in KiB (multiplied by 1024 below)
        blockSize     -- block size in bytes
        associativity -- blocks per set (1 = direct-mapped)
        policy        -- compared against "writethrough" / "writeback"
        allocate      -- write-allocate setting; stored and logged only,
                         writes currently always allocate
        cachePenalty  -- cycles charged per cache access
        memPenalty    -- read/write penalty handed to the backing RAM
        '''

        # CDS: Cache commands
        # 0 means read data, 1 write data, 2 instruction fetch
        self.cmds = self.enum('RD', 'WR', 'IF')

        # Copy arguments into data members
        self.size = size
        self.blockSize = blockSize
        self.assoc = associativity
        self.policy = policy
        self.allocate = allocate
        self.cachePenalty = cachePenalty
        self.memPenalty = memPenalty

        # Cache latencies (all equal to the cache access penalty for now)
        self.readPenalty = self.cachePenalty
        self.writePenalty = self.readPenalty
        self.ifPenalty = self.readPenalty

        # CPU clock, shared with the backing memory below
        self.clock = Clock.Clock()

        # Floor division keeps these ints on Python 2 and Python 3 alike;
        # they are used as list lengths and as a modulus.
        self.nBlocks = self.size * 1024 // self.blockSize  # total blocks
        self.numLines = self.nBlocks // self.assoc  # number of sets

        self.blocks = [None] * self.nBlocks  # stored tag per block (None = empty)
        self.LRUcount = [0] * self.nBlocks  # age of each block within its set
        self.dirty = [False] * self.nBlocks  # per-block dirty bits (write-back)

        # The next memory level can be RAM or L2 cache, here we have RAM
        self.dataStore = Memory.RAM(self.clock,
                                    readPenalty=self.memPenalty,
                                    writePenalty=self.memPenalty)
        # Logging
        self.log = logging.getLogger(str(self))
        self.log.debug('block size: %d', self.blockSize)
        self.log.debug('block count: %d', self.nBlocks)
        self.log.debug('associativity: %s', associativity)
        self.log.debug('write policy: %s', policy)
        self.log.debug('allocate: %s', allocate)
        self.log.debug('cache access penalty (cycles): %d', self.readPenalty)
        self.log.debug('memory access penalty (cycles): %d', memPenalty)

    def getLineForAddress(self, address):
        '''Return the set ("line") index that ``address`` maps to.'''
        return (address // self.blockSize) % self.numLines

    def getTagForAddress(self, address):
        '''Return the tag stored for ``address``.'''
        return address // (self.blockSize * self.numLines)

    def _setRange(self, address):
        '''Return (start, end) block indices of the set for ``address``.'''
        start = self.getLineForAddress(address) * self.assoc
        return start, start + self.assoc

    def _findBlock(self, address):
        '''Return the block index holding ``address``, or None if not cached.

        Does not touch the LRU counters.
        '''
        start, end = self._setRange(address)
        tag = self.getTagForAddress(address)
        for i in range(start, end):
            if self.blocks[i] == tag:
                return i
        return None

    def _addressForBlock(self, block):
        '''Reconstruct the base address of the data cached in ``block``.'''
        line = block // self.assoc
        return (self.blocks[block] * self.numLines + line) * self.blockSize

    def isInCache(self, address):
        '''Look up ``address`` and update LRU ages. Returns True on a hit.

        Every block in the addressed set ages by one access. On a hit, the
        matching block's age is then reset to 0.
        '''
        start, end = self._setRange(address)
        tag = self.getTagForAddress(address)

        hitBlock = None
        for i in range(start, end):
            self.LRUcount[i] += 1  # block data gets 1 access older
            if self.blocks[i] == tag:
                hitBlock = i

        if hitBlock is None:  # Cache miss
            return False
        self.LRUcount[hitBlock] = 0  # Cache hit: most recently used
        return True

    def doCacheRead(self, address):
        '''Read from cache. Incur cache access penalty.'''

        self.clock.advance(self.readPenalty)
        return self.readPenalty

    def doCacheWrite(self, address):
        '''Write a word into the (already cached) block for ``address``.

        Incurs the cache write penalty, then applies the write policy.
        Returns the write penalty.
        '''
        self.clock.advance(self.writePenalty)
        self.enforceWritePolicy(address)
        return self.writePenalty

    def enforceWritePolicy(self, address):
        '''Apply the write policy to a word just written to the cache.

        writethrough: forward the write to memory immediately.
        writeback: mark the block holding ``address`` dirty, so it is
        written to memory when evicted. If the address is not cached, there
        is nothing to mark, and the write goes to memory.
        '''
        if self.policy == "writethrough":
            self.doMemoryWrite(address)
        elif self.policy == "writeback":
            block = self._findBlock(address)
            if block is None:
                self.doMemoryWrite(address)
            else:
                self.dirty[block] = True

    def doDirtyBitCheck(self, address, block=None):
        '''Under write-back, flush a dirty block to memory and clear its bit.

        ``block`` is an index into ``self.blocks``. When omitted, the block
        currently holding ``address`` is used; nothing happens if
        ``address`` is not cached. The write goes to the address of the
        data actually stored in the block (i.e. the victim on eviction).
        '''
        if self.policy != 'writeback':
            return
        if block is None:
            block = self._findBlock(address)
            if block is None:
                return
        if self.dirty[block]:
            self.doMemoryWrite(self._addressForBlock(block))
            self.dirty[block] = False

    def doCacheInsertion(self, address):
        '''Install the block containing ``address`` in its set.

        An empty block in the set is used if there is one. Otherwise the
        least-recently-used block of that set is evicted; for a
        direct-mapped cache this is simply the set's only block. A dirty
        victim is written back first. The filled block starts clean with
        age 0. Returns the index of the filled block.
        '''
        start, end = self._setRange(address)
        tag = self.getTagForAddress(address)

        target = None
        for i in range(start, end):
            if self.blocks[i] is None:  # first empty block
                target = i
                break

        if target is None:
            # CDS: No empty blocks. Evict the oldest block *of this set*;
            # max() returns the lowest index on ties.
            target = max(range(start, end), key=lambda i: self.LRUcount[i])
            # Before eviction, write-back if that is the policy
            self.doDirtyBitCheck(address, target)

        self.blocks[target] = tag
        self.LRUcount[target] = 0
        self.dirty[target] = False
        return target

    def doMemoryRead(self, address):
        '''Fetch ``address`` from lower-level memory and fill it into the cache.

        The backing store's read is invoked; any write-back of an evicted
        dirty block happens during the fill.
        '''

        # If we were actually passing data we would do so here
        self.dataStore.read(address)
        self.doCacheInsertion(address)

    def doMemoryWrite(self, address):
        '''Write ``address`` to the lower-level memory.'''
        self.dataStore.write(address)

    def read(self, address):
        '''Read a word. Returns (hit, cycles spent on this access).'''

        prevCycle = self.clock.currentCycle
        hit = self.isInCache(address)

        if hit:
            self.doCacheRead(address)
        else:
            self.doMemoryRead(address)

        cost = self.clock.currentCycle - prevCycle
        return hit, cost

    def write(self, address):
        '''Write a word. Returns (hit, cycles spent on this access).

        On a miss, the block is fetched first (write-allocate). The word is
        then written into the cache and the write policy is applied.
        '''

        prevCycle = self.clock.currentCycle
        hit = self.isInCache(address)

        if not hit:
            # NOTE(review): self.allocate is not consulted yet, so
            # no-write-allocate is not simulated; writes always allocate.
            self.doMemoryRead(address)
        self.doCacheWrite(address)

        cost = self.clock.currentCycle - prevCycle
        return hit, cost

    def instructionFetch(self, address):
        '''Fetch an instruction. Returns (hit, cycles spent on this access).'''

        prevCycle = self.clock.currentCycle
        hit = self.isInCache(address)

        if hit:
            self.doCacheRead(address)
        else:
            self.doMemoryRead(address)

        cost = self.clock.currentCycle - prevCycle
        return hit, cost

class L1DataCache():
    '''A simulation of a CPU L1 data cache.

    The next level down is an L2DataCache.
    '''

    def __init__(self):
        self.dataStore = L2DataCache()  # next (lower) cache level
        self.cachePenalty = 1  # presumably cycles, as in DataCache

class L2DataCache():
    '''A simulation of a CPU L2 data cache.

    The next level down is an L3DataCache.
    '''

    def __init__(self):
        self.dataStore = L3DataCache()  # next (lower) cache level
        self.cachePenalty = 2  # presumably cycles, as in DataCache

class L3DataCache():
    '''A simulation of a CPU L3 data cache (last cache level here).'''

    def __init__(self):
        self.cachePenalty = 3  # presumably cycles, as in DataCache

class InstructionCache():
    '''A simulation of a CPU instruction cache'''